package nsalt.fetch.branch

import chisel3._
import chisel3.util._

import nsalt._
import chiselFv._

// Pattern History Table (PHT): decides whether a branch should be predicted
// taken, using a 2-bit saturating counter per table entry.
//
// In the original NutShell BPU the PHT is embedded in the BPU itself; it has
// been extracted here into a standalone module. Upstream source:
// https://github.com/OSCPU/NutShell/blob/fd86beadfc47f52973270ce6109edebd2a30363b/src/main/scala/nutcore/frontend/BPU.scala#L112

// Docs can be found at:
// https://oscpu.gitbook.io/nutshell/gong-neng-bu-jian-she-ji-xi-jie/bpu#geng-xin-ji-zhi

/** Pattern History Table: a table of 2-bit saturating counters indexed by `addr.getIdx(pc)`.
  *
  * Prediction: when `io.in.pc` is valid, the counter selected by that PC is read and its
  * MSB is captured into a register that drives `io.out`.
  *
  * Update: `io.in.feed` is registered. In the same cycle the counter for the feedback PC is
  * read and registered, so both are aligned one cycle later. If the registered feedback is
  * valid and describes a conditional branch, the counter is updated as follows:
  *   - taken: incremented, saturating at 0b11;
  *   - not taken: decremented, saturating at 0b00.
  *
  * @param entryCount number of counters in the table
  * @param addr       PC-to-index mapping, used for both prediction and update
  */
class PatternHistory(val entryCount: Int, val addr: EntryAddr) extends Module with Formal {

  val io = IO(new Bundle {

    val in = new PredictPort()

    // MSB of the counter selected by io.in.pc, registered (see RegEnable below).
    val out = Output(Bool())
    // val out = Output(UInt(2.W))

    // for debugging output: exposes the registered feedback that drives the update
    val feedback = new FeedbackPort()

  })

  // Mem: combinational (async) read, synchronous write. A write issued in cycle t becomes
  // visible to reads from cycle t+1 on. A read of the same entry during cycle t still
  // returns the value from before that write.
  val mem = Mem(entryCount, UInt(2.W))

  // Counter for the feedback PC: read now, registered so it lines up with `feedback`.
  // NOTE(review): there is no forwarding from an in-flight write. If feedback for the same
  // index arrives in consecutive cycles, the second read happens before the first update
  // lands, so one update is lost. This matches the upstream NutShell BPU; confirm it is
  // acceptable here.
  val currCount = RegNext(mem.read(addr.getIdx(io.in.feed.pc)))
  val feedback  = RegNext(io.in.feed)

  val taken = feedback.branchTaken
  val nextCount = Mux(taken, currCount + 1.U, currCount - 1.U)

  // Skip the write when the counter is already saturated in the direction of the outcome.
  val shouldUpdate = (taken && (currCount =/= "b11".U)) || (!taken && (currCount =/= "b00".U))

  when (feedback.valid && ArithLogicOperType.isCond(feedback.operType)) {
    when (shouldUpdate) {
      mem.write(addr.getIdx(feedback.pc), nextCount)
    }
  }

  // ========================================================================================
  // FORMAL VERIFICATION
  //
  // An update decided in cycle t (from `feedback` / `currCount`) is written at the end of
  // cycle t. In cycle t+1, an async read of the same index therefore returns the written
  // value. Writes issued in cycle t+1 only land at its end and cannot disturb that read.
  //
  // Only the "counter moved" cases are checked. When the counter was saturated, no write
  // happens. In that case the entry may still have been changed by an earlier write that
  // `currCount` (read one cycle before) did not observe.

  // Previous-cycle copies of the update inputs, elaborated once and shared by every check.
  val feedbackPrev  = RegNext(feedback)
  val currCountPrev = RegNext(currCount)

  val memReadPrev = mem.read(addr.getIdx(feedbackPrev.pc))

  when (feedbackPrev.valid && ArithLogicOperType.isCond(feedbackPrev.operType)) {
    // Taken and not yet at 0b11: the stored counter must have been incremented.
    when (feedbackPrev.branchTaken && (currCountPrev =/= "b11".U)) {
      assert(memReadPrev === currCountPrev + 1.U)
    }
    // Not taken and not yet at 0b00: the stored counter must have been decremented.
    when (!feedbackPrev.branchTaken && (currCountPrev =/= "b00".U)) {
      assert(memReadPrev === currCountPrev - 1.U)
    }
  }

  io.out := RegEnable(mem.read(addr.getIdx(io.in.pc.bits))(1), io.in.pc.valid)

  io.feedback <> feedback
}


